#!/usr/bin/env python3
#
#  Copyright (c) 2023, The OpenThread Authors.
#  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in the
#     documentation and/or other materials provided with the distribution.
#  3. Neither the name of the copyright holder nor the
#     names of its contributors may be used to endorse or promote products
#     derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#

from itertools import filterfalse
import jinja2
from pathlib import Path
import copy
import yaml
import argparse
import os


def prepare_path(path: str) -> str:
    '''Normalize a path or include flag taken from the SLC variables.

    Values that are not exactly of type ``str`` are returned unchanged.
    Strings go through these steps, in order:

    1. a leading ``-I`` is removed
    2. back-slashes become forward-slashes
    3. spaces are escaped with a back-slash
    4. outer double quotes are stripped (only if present on both ends)
    5. ``(SDK_PATH)`` is renamed and SDK-relative OpenThread / PAL locations
       are redirected into the project tree
    '''
    if type(path) is not str:
        return path

    # Steps 1-3. The space escaping must come after the slash conversion,
    # otherwise the escaping back-slashes would be turned into slashes too.
    cleaned = path.removeprefix("-I").replace('\\', '/').replace(' ', '\\ ')

    # Step 4: strip outer quotes
    if cleaned.startswith('"') and cleaned.endswith('"'):
        cleaned = cleaned.strip('"')

    # Step 5: ordered (old, new) substitutions. The redirects match on the
    # '${SILABS_SDK_DIR}' spelling, so the rename has to run first.
    substitutions = (
        # Replace SDK_PATH with SILABS_SDK_DIR (a preceding '$' is untouched)
        ('(SDK_PATH)', '{SILABS_SDK_DIR}'),
        # Redirect OpenThread stack sources to the ot-efr32 openthread submodule
        ('${SILABS_SDK_DIR}/util/third_party/openthread', '${PROJECT_SOURCE_DIR}/openthread'),
        # Redirect PAL sources to the ot-efr32 PAL
        ('${SILABS_SDK_DIR}/protocol/openthread/platform-abstraction/efr32', '${PROJECT_SOURCE_DIR}/src/src'),
    )
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)

    return cleaned


def is_mbedtls_source(source: str) -> bool:
    '''Return True if the source path lies under the SDK's mbedtls directory.'''
    mbedtls_root = '${SILABS_SDK_DIR}/util/third_party/mbedtls'
    return mbedtls_root in source


def is_mbedtls_define(define: str) -> bool:
    '''Return True if the define name starts with "MBEDTLS_" or "PSA_".'''
    mbedtls_prefixes = ("MBEDTLS_", "PSA_")
    return define.startswith(mbedtls_prefixes)


def is_openthread_device_type_define(define: str) -> bool:
    '''Return True if the define is exactly one of the OpenThread device-type names.'''
    return define in (
        "OPENTHREAD_RADIO",
        "OPENTHREAD_FTD",
        "OPENTHREAD_MTD",
    )


def is_openthread_define(define: str) -> bool:
    '''Return True for a generic OPENTHREAD_* define.

    A define qualifies when its name starts with "OPENTHREAD_" and it is
    neither an mbedtls define (see is_mbedtls_define) nor one of the
    device-type defines (see is_openthread_device_type_define).
    '''
    # Every condition below is AND-ed into the accumulator, so it has to start
    # out as True; seeded with False the result would be False for any input.
    r: bool = True
    r &= not is_mbedtls_define(define)
    r &= not is_openthread_device_type_define(define)
    r &= define.startswith("OPENTHREAD_")
    return r


def filter_mbedtls_lib_vars(slc_vars: dict) -> dict:
    '''Build the variable set used for the mbedtls library.

    Returns a deep copy of slc_vars with one extra key, 'MBEDTLS_INCLUDES',
    holding the entries of 'C_CXX_INCLUDES' that are either one of the quoted
    "autogen" / "config" entries or contain one of the selected SDK
    sub-directories. The input dict is not modified.
    '''
    # Includes kept only on an exact match (the quotes are part of the value)
    exact_includes = ('"autogen"', '"config"')

    # Includes kept when any of these fragments appears anywhere in the value
    sdk_fragments = (
        '${SILABS_SDK_DIR}/hardware',
        '${SILABS_SDK_DIR}/platform',
        '${SILABS_SDK_DIR}/util/third_party/crypto',
    )

    def wanted(include: str) -> bool:
        if include == exact_includes[0] or include == exact_includes[1]:
            return True
        return any(fragment in include for fragment in sdk_fragments)

    # Start with a copy of everything, then add the filtered include list
    result: dict = copy.deepcopy(slc_vars)
    result['MBEDTLS_INCLUDES'] = [inc for inc in result['C_CXX_INCLUDES'] if wanted(inc)]
    return result


def filter_platform_lib_vars(slc_vars: dict) -> dict:
    '''Build the variable set used for the platform library.

    No filtering is applied; the result is an independent deep copy of
    slc_vars, so later changes to either dict do not affect the other.
    '''
    return copy.deepcopy(slc_vars)


# Command-line interface: the SLC-generated variables file and an output directory
parser: argparse.ArgumentParser = argparse.ArgumentParser()
parser.add_argument(
    "slc_vars_yaml",
    help="the .yml file generated by slc. This file should contain all the standard SLC template variables")
parser.add_argument("output_dir", help="the output dir for any generated files")
args = parser.parse_args()

# Import slc variables
# NOTE(review): yaml.Loader is PyYAML's unrestricted loader and can construct
# arbitrary Python objects. That is presumably acceptable because the input is
# a locally generated file; consider yaml.safe_load if it could be untrusted.
with open(args.slc_vars_yaml, mode="r") as slc_vars_yaml:
    # 'global' has no effect at module scope; slc_vars is a module-level name either way
    global slc_vars
    slc_vars = yaml.load(slc_vars_yaml, Loader=yaml.Loader)

# Replace empty values with empty lists for easy list concatenation.
# The truthiness test means every falsy value (None, '', 0, empty containers),
# not only None, is replaced by [].
for k, v in slc_vars.items():
    if not v:
        slc_vars[k] = list()


# Remove mapfile (-Map=), --specs= and linkerfile.ld flags from linker options
def f(flag: str) -> bool:
    '''Predicate for filter(): True when a linker flag should be kept.

    A flag is dropped (False) if it contains "-Map=", "--specs=" or
    "linkerfile.ld"; every other flag is kept.
    '''
    unwanted_fragments = ("-Map=", "--specs=", "linkerfile.ld")
    return not any(fragment in flag for fragment in unwanted_fragments)


# Apply the keep-predicate f() to both linker flag lists.
# Both keys must be present in slc_vars; a missing key raises KeyError here.
for key in ["EXT_LD_FLAGS", "EXT_DEBUG_LD_FLAGS"]:
    flags = slc_vars[key]
    slc_vars[key] = list(filter(f, flags))

# Do some general cleaning up of paths
keys_to_prepare = [
    "C_CXX_INCLUDES",
    "ALL_SOURCES",
    "SYS_LIBS",
    "USER_LIBS",
]
for key in keys_to_prepare:
    value = slc_vars[key]
    # Entries are rewritten in place; lists are updated by index, dicts by key
    if type(value) is list:
        for i in range(len(slc_vars[key])):
            path = value[i]
            value[i] = prepare_path(path)
    elif type(value) is dict:
        for k, v in value.items():
            value[k] = prepare_path(v)

# Remove None values from every list-valued variable.
# The identity lambda filters on truthiness, so '' and other falsy items are
# dropped as well.
for key in slc_vars:
    if type(slc_vars[key]) is list:
        slc_vars[key] = list(filter(lambda item: item, slc_vars[key]))

# Separate mbedtls sources
slc_vars['MBEDTLS_SOURCES'] = list(filter(is_mbedtls_source, slc_vars['ALL_SOURCES']))
slc_vars['NON_MBEDTLS_SOURCES'] = list(filterfalse(is_mbedtls_source, slc_vars['ALL_SOURCES']))

# ==============================================================================
# Filter defines
# ==============================================================================
# 'defines' aliases slc_vars['C_CXX_DEFINES'], so each 'del' below also removes
# the entry from C_CXX_DEFINES. The three steps run in sequence on the same
# object: a define claimed by an earlier step is no longer seen by a later one.
defines = slc_vars['C_CXX_DEFINES']

# Separate mbedtls defines
slc_vars['MBEDTLS_DEFINES'] = {d: defines[d] for d in defines if is_mbedtls_define(d)}
for d in slc_vars['MBEDTLS_DEFINES']:
    del defines[d]

# Filter OPENTHREAD device type and remove from C_CXX_DEFINES
slc_vars['OPENTHREAD_DEVICE_TYPE'] = {d: defines[d] for d in defines if is_openthread_device_type_define(d)}
for d in slc_vars['OPENTHREAD_DEVICE_TYPE']:
    del defines[d]

# Filter OPENTHREAD_* defines
slc_vars['OPENTHREAD_DEFINES'] = {d: defines[d] for d in defines if is_openthread_define(d)}
for d in slc_vars['OPENTHREAD_DEFINES']:
    del defines[d]

# filter slc vars
# Each call returns an independent deep copy, one per template
platform_lib_vars = filter_platform_lib_vars(slc_vars)
mbedtls_lib_vars = filter_mbedtls_lib_vars(slc_vars)

# Define CMakeLists.txt template location
# Templates are looked up relative to the parent of this script's directory
script_dir: Path = Path(os.path.dirname(os.path.abspath(__file__)))
repo_root: Path = script_dir.parent
environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader(f"{repo_root}/third_party/silabs/slc/exporter_templates/platform_library"))
platform_lib_template: jinja2.Template = environment.get_template("CMakeLists.txt.jinja")
mbedtls_lib_template: jinja2.Template = environment.get_template("mbedtls.cmake.jinja")

# Render the template with the imported variables
# (the keys of each dict become the template's variables)
platform_lib_content = platform_lib_template.render(platform_lib_vars)
mbedtls_lib_content = mbedtls_lib_template.render(mbedtls_lib_vars)

# Output rendered files
# NOTE(review): both files are written next to the input YAML file;
# args.output_dir is not used in the code visible here - confirm that is intended.
slc_vars_yaml_dir: Path = Path(args.slc_vars_yaml).parent
platform_lib_output_file: Path = slc_vars_yaml_dir / "CMakeLists.txt"
with open(platform_lib_output_file, mode="w", encoding="utf-8") as message:
    message.write(platform_lib_content)
    print(f"... wrote {platform_lib_output_file}")
mbedtls_lib_output_file: Path = slc_vars_yaml_dir / "mbedtls.cmake"
with open(mbedtls_lib_output_file, mode="w", encoding="utf-8") as message:
    message.write(mbedtls_lib_content)
    print(f"... wrote {mbedtls_lib_output_file}")
